
chore: Support multi-GPU training via accelerate #5

Open

wants to merge 6 commits into main

Conversation

@bclavie (Contributor) commented Dec 22, 2023

Hey! Great work on the library. I've been playing with it and ran into a few issues with in-place operations when trying to train on multiple GPUs:

  • BERT as implemented in HuggingFace transformers has a known issue (this was really annoying to track down) where it needs explicit position_ids to run on multi-GPU (sketch below)
  • The scatter_() call used to update activations was a bit finicky
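
For reference, here's a rough sketch of the position_ids workaround (illustrative only; the checkpoint name and setup here are placeholders, the actual change lives in the base model's encoding helper):

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Illustrative sketch: build position_ids explicitly so the HF BERT forward
# does not create them internally, which trips up multi-GPU training.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased").to(device)

texts = ["a first sentence", "a second sentence"]
encoded = tokenizer(texts, return_tensors="pt", padding=True).to(device)

# One row of [0, 1, ..., seq_len - 1] per input text.
position_ids = (
    torch.arange(0, encoded["input_ids"].size(1))
    .expand((len(texts), -1))
    .to(device)
)

output = model(**encoded, position_ids=position_ids)
```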

Setting the device this way also really doesn't play nice with the default tokeniser export, so there's a workaround to export the tokeniser files individually rather than rely on risky JSON decoding.

I've also added a doc page to show how simple it is to parallelise training with just those few changes and some very slight code modifications in a training script.

@raphaelsty (Owner) commented Dec 22, 2023

Congratulations on this amazing work @bclavie 🤩,

Thank you also for the documentation with the DataLoader.

I'll run your branch in the coming days to make sure everything runs smoothly, and then merge and release a new version.

@raphaelsty self-requested a review Dec 22, 2023
@raphaelsty added the enhancement label Dec 22, 2023
@bclavie (Contributor, Author) commented Dec 22, 2023

Thank you! Please do let me know if you run into any issues -- things are training fine right now but I'm using a pretty weird setup so there might still be some issues.

> Thank you also for the documentation with the DataLoader.

To be fair there's no code there at the moment, but I'm happy to update with mock data in a bit if you think it'd be useful!

@raphaelsty (Owner) commented Dec 22, 2023

I don't have multiple GPUs (not even one) at home, so I cannot mimic your environment.

I propose to add the accelerate attribute to all the models. If set to False it will call the tokenizer.encode_batch method, and otherwise it will call your encoding procedure. I did this because using position_ids raises an error with DistilBERT but works fine with sentence-transformers: neural_cherche/models/base.py

I also updated the documentation a bit in order to show how to create a dataset.

All tests pass locally with the code from your branch and my updates; feel free to copy-paste the code I commented.

Also, which versions of transformers and accelerate are you using?

@bclavie (Contributor, Author) commented Dec 23, 2023

> All tests pass locally with the code from your branch and my updates; feel free to copy-paste the code I commented.

Hey, did you submit the comments? I can't see the suggested code anywhere, though it might be me being holiday-tired...

Thank you for taking the time to look at this and improving it! I'm running transformers==4.36.2 and accelerate==0.25.0

I've run some more experiments, and for full disclosure so far:

  • Results when training SparseEmbed on a single GPU and on multiple GPUs seem identical
  • It doesn't work as smoothly for ColBERT as for SparseEmbed; there's probably an issue there at the moment.
  • My position_ids workaround fixes things for BERT-based models, but currently XLMRoBERTa (and presumably RoBERTa itself) causes the same crash as BERT did pre-modification (an in-place operation modifying tensors on the wrong device)

My feeling is that it might actually be unsafe to merge as a "mature" feature at this stage, but merging it labelled as experimental support could be useful?

(as for neural-cherche itself, I really like the lightweight-ness of the lib, but currently I'm running into some issues where my models end up stuck in some kind of "compressed similarity" land and hard negatives are always extremely close to positives in similarity, which doesn't happen with the main ColBERT-codebase -- I'm training a ColBERT from scratch and will try to diagnose once I have more time!)

@raphaelsty (Owner) commented:

# Multi-GPU

Neural-Cherche is compatible with multi-GPU training using [Accelerate](https://huggingface.co/docs/accelerate/package_reference/accelerator). Every model in neural-cherche can be trained on multiple GPUs. Here is a tutorial.

```python
import torch
from accelerate import Accelerator
from datasets import Dataset
from torch.utils.data import DataLoader

from neural_cherche import models, train

if __name__ == "__main__":
    # Keep the training loop under the __main__ guard to avoid multiprocessing issues.
    accelerator = Accelerator()
    save_each_epoch = True

    model = models.SparseEmbed(
        model_name_or_path="distilbert-base-uncased",
        accelerate=True,
        device=accelerator.device,
    ).to(accelerator.device)

    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)

    # Dataset creation using HuggingFace Datasets library.
    dataset = Dataset.from_dict(
        {
            "anchors": ["anchor 1", "anchor 2", "anchor 3", "anchor 4"],
            "positives": ["positive 1", "positive 2", "positive 3", "positive 4"],
            "negatives": ["negative 1", "negative 2", "negative 3", "negative 4"],
        }
    )

    # Convert your dataset to a DataLoader.
    data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Wrap model, optimizer, and dataloader in accelerator.
    model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

    for epoch in range(2):
        for batch in data_loader:
            # Each batch is a dict with "anchors", "positives" and "negatives" keys.
            anchors, positives, negatives = (
                batch["anchors"],
                batch["positives"],
                batch["negatives"],
            )

            loss = train.train_sparse_embed(
                model=model,
                optimizer=optimizer,
                anchor=anchors,
                positive=positives,
                negative=negatives,
                threshold_flops=30,
                accelerator=accelerator,
            )

        if accelerator.is_main_process and save_each_epoch:
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(
                "checkpoint/epoch" + str(epoch),
            )

    # Save at the end of the training loop
    # We check to make sure that only the main process will export the model
    if accelerator.is_main_process:
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained("checkpoint")
```
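
Assuming a standard Accelerate setup, the script above is then launched with `accelerate launch train.py` (after running `accelerate config` once); `train.py` is just a placeholder name for the training script shown here.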

@raphaelsty (Owner) commented:

I added a clear example of how to create the dataset using the HuggingFace Datasets library.

@raphaelsty (Owner) commented:

Got some trouble with the extra position_ids parameter with the DistilBERT pre-trained checkpoint but not with the all-mpnet-base-v2 pre-trained checkpoint, so I think it would be cool to keep the legacy code and add an accelerate attribute to the models.

@raphaelsty (Owner) commented:

```python

import json
import os
from abc import ABC, abstractmethod

import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModelForMaskedLM, AutoTokenizer


class Base(ABC, torch.nn.Module):
    """Base class from which all models inherit.

    Parameters
    ----------
    model_name_or_path
        Path to the model or the model name.
    device
        Device to use for the model. CPU or CUDA.
    extra_files_to_load
        List of extra files to load.
    accelerate
        Use HuggingFace Accelerate.
    kwargs
        Additional parameters to the model.
    """

    def __init__(
        self,
        model_name_or_path: str,
        device: str = None,
        extra_files_to_load: list[str] = [],
        accelerate: bool = False,
        **kwargs,
    ) -> None:
        """Initialize the model."""
        super(Base, self).__init__()

        if device is not None:
            self.device = device

        elif torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"

        self.accelerate = accelerate

        os.environ["TRANSFORMERS_CACHE"] = "."
        self.model = AutoModelForMaskedLM.from_pretrained(
            model_name_or_path, cache_dir="./", **kwargs
        ).to(self.device)

        # Download extra files (e.g. the linear layer) if they exist.
        for file in extra_files_to_load:
            try:
                _ = hf_hub_download(model_name_or_path, filename=file, cache_dir=".")
            except Exception:
                # The file may not exist for this checkpoint; skip it.
                pass

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, device=self.device, cache_dir="./", **kwargs
        )

        self.model.config.output_hidden_states = True

        if os.path.exists(model_name_or_path):
            # Local checkpoint
            self.model_folder = model_name_or_path
        else:
            # HuggingFace checkpoint
            model_folder = os.path.join(
                f"models--{model_name_or_path}".replace("/", "--"), "snapshots"
            )
            snapshot = os.listdir(model_folder)[-1]
            self.model_folder = os.path.join(model_folder, snapshot)

        self.query_pad_token = self.tokenizer.mask_token
        self.original_pad_token = self.tokenizer.pad_token

    def _encode_accelerate(self, texts: list[str], **kwargs) -> tuple[torch.Tensor]:
        """Encode sentences with multiples gpus.

        Parameters
        ----------
        texts
            List of sentences to encode.

        References
        ----------
        [Accelerate issue.](https://github.com/huggingface/accelerate/issues/97)
        """
        encoded_input = self.tokenizer(texts, return_tensors="pt", **kwargs).to(
            self.device
        )

        position_ids = (
            torch.arange(0, encoded_input["input_ids"].size(1))
            .expand((len(texts), -1))
            .to(self.device)
        )

        output = self.model(**encoded_input, position_ids=position_ids)
        return output.logits, output.hidden_states[-1]

    def _encode(self, texts: list[str], **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode sentences.

        Parameters
        ----------
        texts
            List of sentences to encode.
        """
        if self.accelerate:
            return self._encode_accelerate(texts, **kwargs)

        encoded_input = self.tokenizer.batch_encode_plus(
            texts, return_tensors="pt", **kwargs
        )

        if self.device != "cpu":
            encoded_input = {
                key: value.to(self.device) for key, value in encoded_input.items()
            }

        output = self.model(**encoded_input)
        return output.logits, output.hidden_states[-1]

    @abstractmethod
    def forward(self, *args, **kwargs):
        """Pytorch forward method."""
        pass

    @abstractmethod
    def encode(self, *args, **kwargs):
        """Encode documents."""
        pass

    @abstractmethod
    def scores(self, *args, **kwargs):
        """Compute scores."""
        pass

    @abstractmethod
    def save_pretrained(self, path: str):
        """Save model the model."""
        pass

    def save_tokenizer_accelerate(self, path: str) -> None:
        """Save tokenizer when using accelerate."""
        tokenizer_config = {
            k: v for k, v in self.tokenizer.__dict__.items() if k != "device"
        }
        tokenizer_config_file = os.path.join(path, "tokenizer_config.json")
        with open(tokenizer_config_file, "w", encoding="utf-8") as file:
            json.dump(tokenizer_config, file, ensure_ascii=False, indent=4)

        # dump vocab
        self.tokenizer.save_vocabulary(path)

        # save special tokens
        special_tokens_file = os.path.join(path, "special_tokens_map.json")
        with open(special_tokens_file, "w", encoding="utf-8") as file:
            json.dump(
                self.tokenizer.special_tokens_map,
                file,
                ensure_ascii=False,
                indent=4,
            )

```

@raphaelsty (Owner) commented:

Here is the base class, updated with the new save_tokenizer_accelerate method and the accelerate attribute.

@raphaelsty (Owner) commented:

```python

import json
import os

import torch

from .. import utils
from .base import Base

__all__ = ["ColBERT"]


class ColBERT(Base):
    """ColBERT model.

    Parameters
    ----------
    model_name_or_path
        Path to the model or the model name.
    embedding_size
        Size of the embeddings in output of ColBERT model.
    device
        Device to use for the model. CPU or CUDA.
    accelerate
        Use HuggingFace Accelerate.
    kwargs
        Additional parameters to the SentenceTransformer model.

    Examples
    --------
    >>> from neural_cherche import models
    >>> import torch

    >>> _ = torch.manual_seed(42)

    >>> queries = ["Berlin", "Paris", "London"]

    >>> documents = [
    ...     "Berlin is the capital of Germany",
    ...     "Paris is the capital of France and France is in Europe",
    ...     "London is the capital of England",
    ... ]

    >>> encoder = models.ColBERT(
    ...     model_name_or_path="sentence-transformers/all-mpnet-base-v2",
    ...     embedding_size=128,
    ...     max_length_query=32,
    ...     max_length_document=350,
    ... )

    >>> scores = encoder.scores(
    ...    queries=queries,
    ...    documents=documents,
    ... )

    >>> scores
    tensor([22.9325, 19.8296, 20.8019])

    >>> _ = encoder.save_pretrained("checkpoint", accelerate=False)

    >>> encoder = models.ColBERT(
    ...     model_name_or_path="checkpoint",
    ...     embedding_size=64,
    ...     device="cpu",
    ... )

    >>> scores = encoder.scores(
    ...    queries=queries,
    ...    documents=documents,
    ... )

    >>> scores
    tensor([22.9325, 19.8296, 20.8019])

    >>> embeddings = encoder(
    ...     texts=queries,
    ...     query_mode=True
    ... )

    >>> embeddings["embeddings"].shape
    torch.Size([3, 32, 128])

    >>> embeddings = encoder(
    ...     texts=queries,
    ...     query_mode=False
    ... )

    >>> embeddings["embeddings"].shape
    torch.Size([3, 350, 128])

    """

    def __init__(
        self,
        model_name_or_path: str,
        embedding_size: int = 128,
        device: str = None,
        max_length_query: int = 32,
        max_length_document: int = 350,
        accelerate: bool = False,
        **kwargs,
    ) -> None:
        """Initialize the model."""
        super(ColBERT, self).__init__(
            model_name_or_path=model_name_or_path,
            device=device,
            extra_files_to_load=["linear.pt", "metadata.json"],
            accelerate=accelerate,
            **kwargs,
        )

        self.max_length_query = max_length_query
        self.max_length_document = max_length_document
        self.embedding_size = embedding_size

        if os.path.exists(os.path.join(self.model_folder, "linear.pt")):
            linear = torch.load(
                os.path.join(self.model_folder, "linear.pt"), map_location=self.device
            )
            self.embedding_size = linear["weight"].shape[0]
            in_features = linear["weight"].shape[1]
        else:
            with torch.no_grad():
                _, embeddings = self._encode(texts=["test"])
                in_features = embeddings.shape[2]

        self.linear = torch.nn.Linear(
            in_features=in_features,
            out_features=self.embedding_size,
            bias=False,
            device=self.device,
        )

        if os.path.exists(os.path.join(self.model_folder, "metadata.json")):
            with open(os.path.join(self.model_folder, "metadata.json"), "r") as f:
                metadata = json.load(f)
            self.max_length_document = metadata["max_length_document"]
            self.max_length_query = metadata["max_length_query"]

        if os.path.exists(os.path.join(self.model_folder, "linear.pt")):
            self.linear.load_state_dict(linear)

    def encode(
        self,
        texts: list[str],
        truncation: bool = True,
        add_special_tokens: bool = False,
        query_mode: bool = True,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Encode documents

        Parameters
        ----------
        texts
            List of sentences to encode.
        truncation
            Truncate the inputs.
        add_special_tokens
            Add special tokens.
        max_length
            Maximum length of the inputs.
        """
        with torch.no_grad():
            embeddings = self(
                texts=texts,
                truncation=truncation,
                add_special_tokens=add_special_tokens,
                query_mode=query_mode,
                **kwargs,
            )
        return embeddings

    def forward(
        self,
        texts: list[str],
        query_mode: bool = True,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Pytorch forward method.

        Parameters
        ----------
        texts
            List of sentences to encode.
        query_mode
            Whether to encode queries or documents.
        """
        suffix = "[Q] " if query_mode else "[D] "

        texts = [suffix + text for text in texts]

        self.tokenizer.pad_token = (
            self.query_pad_token if query_mode else self.original_pad_token
        )

        kwargs = {
            "truncation": True,
            "padding": "max_length",
            "max_length": self.max_length_query
            if query_mode
            else self.max_length_document,
            "add_special_tokens": True,
            **kwargs,
        }

        _, embeddings = self._encode(texts=texts, **kwargs)

        return {
            "embeddings": torch.nn.functional.normalize(
                self.linear(embeddings), p=2, dim=2
            )
        }

    def scores(
        self,
        queries: list[str],
        documents: list[str],
        batch_size: int = 2,
        tqdm_bar: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Score queries and documents.

        Parameters
        ----------
        queries
            List of queries.
        documents
            List of documents.
        batch_size
            Batch size.
        truncation
            Truncate the inputs.
        add_special_tokens
            Add special tokens.
        tqdm_bar
            Show tqdm bar.
        """
        list_scores = []

        for batch_queries, batch_documents in zip(
            utils.batchify(
                X=queries,
                batch_size=batch_size,
                desc="Computing scores.",
                tqdm_bar=tqdm_bar,
            ),
            utils.batchify(X=documents, batch_size=batch_size, tqdm_bar=False),
        ):
            queries_embeddings = self.encode(
                texts=batch_queries,
                query_mode=True,
                **kwargs,
            )

            documents_embeddings = self.encode(
                texts=batch_documents,
                query_mode=False,
                **kwargs,
            )

            late_interactions = torch.einsum(
                "bsh,bth->bst",
                queries_embeddings["embeddings"],
                documents_embeddings["embeddings"],
            )

            late_interactions = torch.max(late_interactions, axis=2).values.sum(axis=1)

            list_scores.append(late_interactions)

        return torch.cat(list_scores, dim=0)

    def save_pretrained(self, path: str) -> "ColBERT":
        """Save model the model.

        Parameters
        ----------
        path
            Path to save the model.
        """
        self.model.save_pretrained(path)
        torch.save(self.linear.state_dict(), os.path.join(path, "linear.pt"))
        self.tokenizer.pad_token = self.original_pad_token
        with open(os.path.join(path, "metadata.json"), "w") as f:
            json.dump(
                {
                    "max_length_query": self.max_length_query,
                    "max_length_document": self.max_length_document,
                },
                f,
            )
        if self.accelerate:
            self.save_tokenizer_accelerate(path=path)
        else:
            self.tokenizer.save_pretrained(path)
        return self

```

@raphaelsty (Owner) commented:

ColBERT with the call to the parent class's save_tokenizer_accelerate method :)

@raphaelsty (Owner) commented:

```python

import json
import os

import torch

from .. import utils

__all__ = ["SparseEmbed"]

from .splade import Splade


class SparseEmbed(Splade):
    """SparseEmbed model.

    Parameters
    ----------
    model_name_or_path
        Path to the model or the model name. It should be a SentenceTransformer model.
    embedding_size
        Size of the output embeddings of the SparseEmbed model.
    kwargs
        Additional parameters to the pre-trained model.

    Examples
    --------
    >>> from neural_cherche import models
    >>> import torch

    >>> _ = torch.manual_seed(42)

    >>> device = "mps"

    >>> model = models.SparseEmbed(
    ...     model_name_or_path="distilbert-base-uncased",
    ...     device=device,
    ... )

    >>> queries_embeddings = model.encode(
    ...     ["Sports", "Music"],
    ... )

    >>> queries_embeddings["activations"].shape
    torch.Size([2, 128])

    >>> queries_embeddings["sparse_activations"].shape
    torch.Size([2, 30522])

    >>> queries_embeddings["embeddings"].shape
    torch.Size([2, 128, 128])

    >>> documents_embeddings = model.encode(
    ...    ["Music is great.", "Sports is great."],
    ...    query_mode=False,
    ... )

    >>> documents_embeddings["activations"].shape
    torch.Size([2, 256])

    >>> documents_embeddings["sparse_activations"].shape
    torch.Size([2, 30522])

    >>> documents_embeddings["embeddings"].shape
    torch.Size([2, 256, 128])

    >>> model.scores(
    ...     queries=["Sports", "Music"],
    ...     documents=["Sports is great.", "Music is great."],
    ...     batch_size=1,
    ... )
    tensor([64.2330, 54.0180], device='mps:0')

    >>> _ = model.save_pretrained("checkpoint")

    >>> model = models.SparseEmbed(
    ...     model_name_or_path="checkpoint",
    ...     device="cpu",
    ... )

    >>> model.scores(
    ...     queries=["Sports", "Music"],
    ...     documents=["Sports is great.", "Music is great."],
    ...     batch_size=2,
    ... )
    tensor([64.2330, 54.0180])

    References
    ----------
    1. [SparseEmbed: Learning Sparse Lexical Representations with Contextual Embeddings for Retrieval](https://dl.acm.org/doi/pdf/10.1145/3539618.3592065)

    """

    def __init__(
        self,
        model_name_or_path: str = None,
        embedding_size: int = 128,
        max_length_query: int = 128,
        max_length_document: int = 256,
        device: str = None,
        accelerate: bool = False,
        **kwargs,
    ) -> None:
        super(SparseEmbed, self).__init__(
            model_name_or_path=model_name_or_path,
            device=device,
            extra_files_to_load=["linear.pt", "metadata.json"],
            accelerate=accelerate,
            **kwargs,
        )

        self.embedding_size = embedding_size

        self.softmax = torch.nn.Softmax(dim=2).to(self.device)

        if os.path.exists(os.path.join(self.model_folder, "linear.pt")):
            linear = torch.load(
                os.path.join(self.model_folder, "linear.pt"), map_location=self.device
            )
            self.embedding_size = linear["weight"].shape[0]
            in_features = linear["weight"].shape[1]
        else:
            with torch.no_grad():
                _, embeddings = self._encode(texts=["test"])
                in_features = embeddings.shape[2]

        self.linear = torch.nn.Linear(
            in_features=in_features,
            out_features=self.embedding_size,
            bias=False,
            device=self.device,
        )

        if os.path.exists(os.path.join(self.model_folder, "linear.pt")):
            self.linear.load_state_dict(linear)

        if os.path.exists(os.path.join(self.model_folder, "metadata.json")):
            with open(os.path.join(self.model_folder, "metadata.json"), "r") as file:
                metadata = json.load(file)

            max_length_query = metadata["max_length_query"]
            max_length_document = metadata["max_length_document"]

        self.max_length_query = max_length_query
        self.max_length_document = max_length_document

    def forward(
        self,
        texts: list[str],
        query_mode: bool = True,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Pytorch forward method.

        Parameters
        ----------
        texts
            List of documents to encode.
        query_mode
            Whether to encode queries or documents.
        """
        suffix = "[Q] " if query_mode else "[D] "

        texts = [suffix + text for text in texts]

        self.tokenizer.pad_token = (
            self.query_pad_token if query_mode else self.original_pad_token
        )

        k_tokens = self.max_length_query if query_mode else self.max_length_document

        logits, embeddings = self._encode(
            texts=texts,
            truncation=True,
            padding="max_length",
            max_length=k_tokens,
            add_special_tokens=True,
            **kwargs,
        )

        activations = self._update_activations(
            **self._get_activation(logits=logits),
            k_tokens=k_tokens,
        )

        attention = self._get_attention(
            logits=logits,
            activations=activations["activations"],
        )

        embeddings = torch.bmm(
            attention,
            embeddings,
        )

        return {
            "embeddings": self.relu(self.linear(embeddings)),
            "sparse_activations": activations["sparse_activations"],
            "activations": activations["activations"],
        }

    def _get_attention(
        self, logits: torch.Tensor, activations: torch.Tensor
    ) -> torch.Tensor:
        """Extract attention scores from MLM logits based on activated tokens."""
        attention = logits.gather(
            dim=2,
            index=torch.stack(
                [
                    torch.stack([token for _ in range(logits.shape[1])])
                    for token in activations
                ]
            ),
        )

        return self.softmax(attention)

    def save_pretrained(
        self,
        path: str,
    ):
        """Save model the model."""
        self.model.save_pretrained(path)
        self.tokenizer.pad_token = self.original_pad_token

        if self.accelerate:
            self.save_tokenizer_accelerate(path)
        else:
            self.tokenizer.save_pretrained(path)
        torch.save(self.linear.state_dict(), os.path.join(path, "linear.pt"))
        with open(os.path.join(path, "metadata.json"), "w") as file:
            json.dump(
                fp=file,
                obj={
                    "max_length_query": self.max_length_query,
                    "max_length_document": self.max_length_document,
                },
                indent=4,
            )

        return self

    def scores(
        self,
        queries: list[str],
        documents: list[str],
        batch_size: int = 32,
        tqdm_bar: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Compute similarity scores between queries and documents."""
        dense_scores = []

        for batch_queries, batch_documents in zip(
            utils.batchify(
                X=queries,
                batch_size=batch_size,
                desc="Computing scores.",
                tqdm_bar=tqdm_bar,
            ),
            utils.batchify(X=documents, batch_size=batch_size, tqdm_bar=False),
        ):
            queries_embeddings = self.encode(
                texts=batch_queries,
                query_mode=True,
                **kwargs,
            )

            documents_embeddings = self.encode(
                texts=batch_documents,
                query_mode=False,
                **kwargs,
            )

            dense_scores.append(
                utils.pairs_dense_scores(
                    queries_activations=queries_embeddings["activations"],
                    documents_activations=documents_embeddings["activations"],
                    queries_embeddings=queries_embeddings["embeddings"],
                    documents_embeddings=documents_embeddings["embeddings"],
                )
            )

        return torch.cat(dense_scores, dim=0)

```

@raphaelsty (Owner) commented:

SparseEmbed with the call to the parent class's save_tokenizer_accelerate method :)

@raphaelsty (Owner) commented:

```python

import json
import os
import string

import torch

from .. import utils
from .base import Base

__all__ = ["Splade"]


class Splade(Base):
    """SpladeV1 model.

    Parameters
    ----------
    model_name_or_path
        Path to the model or the model name.
    device
        Device to use for the model. CPU or CUDA.
    kwargs
        Additional parameters to the SentenceTransformer model.

    Examples
    --------
    >>> from neural_cherche import models
    >>> import torch

    >>> _ = torch.manual_seed(42)

    >>> model = models.Splade(
    ...     model_name_or_path="distilbert-base-uncased",
    ...     device="mps",
    ... )

    >>> queries_activations = model.encode(
    ...     ["Sports", "Music"],
    ... )

    >>> documents_activations = model.encode(
    ...    ["Music is great.", "Sports is great."],
    ...    query_mode=False,
    ... )

    >>> queries_activations["sparse_activations"].shape
    torch.Size([2, 30522])

    >>> model.scores(
    ...     queries=["Sports", "Music"],
    ...     documents=["Sports is great.", "Music is great."],
    ...     batch_size=1
    ... )
    tensor([318.1384, 271.8006], device='mps:0')

    >>> _ = model.save_pretrained("checkpoint")

    >>> model = models.Splade(
    ...     model_name_or_path="checkpoint",
    ...     device="mps",
    ... )

    >>> model.scores(
    ...     queries=["Sports", "Music"],
    ...     documents=["Sports is great.", "Music is great."],
    ...     batch_size=1
    ... )
    tensor([318.1384, 271.8006], device='mps:0')

    References
    ----------
    1. [SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking](https://arxiv.org/abs/2107.05720)

    """

    def __init__(
        self,
        model_name_or_path: str = None,
        device: str = None,
        max_length_query: int = 128,
        max_length_document: int = 256,
        extra_files_to_load: list[str] = ["metadata.json"],
        accelerate: bool = False,
        **kwargs,
    ) -> None:
        super(Splade, self).__init__(
            model_name_or_path=model_name_or_path,
            device=device,
            extra_files_to_load=extra_files_to_load,
            accelerate=accelerate,
            **kwargs,
        )

        self.relu = torch.nn.ReLU().to(self.device)

        if os.path.exists(os.path.join(self.model_folder, "metadata.json")):
            with open(os.path.join(self.model_folder, "metadata.json"), "r") as file:
                metadata = json.load(file)

            max_length_query = metadata["max_length_query"]
            max_length_document = metadata["max_length_document"]

        self.max_length_query = max_length_query
        self.max_length_document = max_length_document

    def encode(
        self,
        texts: list[str],
        query_mode: bool = True,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Encode documents

        Parameters
        ----------
        texts
            List of documents to encode.
        truncation
            Whether to truncate the documents.
        padding
            Whether to pad the documents.
        max_length
            Maximum length of the documents.
        """
        with torch.no_grad():
            return self(
                texts=texts,
                query_mode=query_mode,
                **kwargs,
            )

    def decode(
        self,
        sparse_activations: torch.Tensor,
        clean_up_tokenization_spaces: bool = False,
        skip_special_tokens: bool = True,
        k_tokens: int = 96,
    ) -> list[str]:
        """Decode activated tokens ids where activated value > 0.

        Parameters
        ----------
        sparse_activations
            Activated tokens.
        clean_up_tokenization_spaces
            Whether to clean up the tokenization spaces.
        skip_special_tokens
            Whether to skip special tokens.
        k_tokens
            Number of tokens to keep.
        """
        activations = self._filter_activations(
            sparse_activations=sparse_activations, k_tokens=k_tokens
        )

        # Decode
        return [
            " ".join(
                activation.translate(str.maketrans("", "", string.punctuation)).split()
            )
            for activation in self.tokenizer.batch_decode(
                activations,
                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                skip_special_tokens=skip_special_tokens,
            )
        ]

    def forward(
        self,
        texts: list[str],
        query_mode: bool,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Pytorch forward method.

        Parameters
        ----------
        texts
            List of documents to encode.
        query_mode
            Whether to encode queries or documents.
        """
        suffix = "[Q] " if query_mode else "[D] "

        texts = [suffix + text for text in texts]

        self.tokenizer.pad_token = (
            self.query_pad_token if query_mode else self.original_pad_token
        )

        k_tokens = self.max_length_query if query_mode else self.max_length_document

        logits, _ = self._encode(
            texts=texts,
            truncation=True,
            padding="max_length",
            max_length=k_tokens,
            add_special_tokens=True,
            **kwargs,
        )

        activations = self._get_activation(logits=logits)

        activations = self._update_activations(
            **activations,
            k_tokens=k_tokens,
        )

        return {"sparse_activations": activations["sparse_activations"]}

    def save_pretrained(
        self,
        path: str,
    ):
        """Save model the model.

        Parameters
        ----------
        path
            Path to save the model.

        """
        self.model.save_pretrained(path)
        self.tokenizer.pad_token = self.original_pad_token

        if self.accelerate:
            self.save_tokenizer_accelerate(path)
        else:
            self.tokenizer.save_pretrained(path)

        with open(os.path.join(path, "metadata.json"), "w") as file:
            json.dump(
                fp=file,
                obj={
                    "max_length_query": self.max_length_query,
                    "max_length_document": self.max_length_document,
                },
                indent=4,
            )

        return self

    def scores(
        self,
        queries: list[str],
        documents: list[str],
        batch_size: int = 32,
        tqdm_bar: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Compute similarity scores between queries and documents.

        Parameters
        ----------
        queries
            List of queries.
        documents
            List of documents.
        batch_size
            Batch size.
        tqdm_bar
            Show a progress bar.
        """
        sparse_scores = []

        for batch_queries, batch_documents in zip(
            utils.batchify(
                X=queries,
                batch_size=batch_size,
                desc="Computing scores.",
                tqdm_bar=tqdm_bar,
            ),
            utils.batchify(X=documents, batch_size=batch_size, tqdm_bar=False),
        ):
            queries_embeddings = self.encode(
                batch_queries,
                query_mode=True,
                **kwargs,
            )

            documents_embeddings = self.encode(
                batch_documents,
                query_mode=False,
                **kwargs,
            )

            sparse_scores.append(
                torch.sum(
                    queries_embeddings["sparse_activations"]
                    * documents_embeddings["sparse_activations"],
                    axis=1,
                )
            )

        return torch.cat(sparse_scores, dim=0)

    def _get_activation(self, logits: torch.Tensor) -> dict[str, torch.Tensor]:
        """Returns activated tokens."""
        return {"sparse_activations": torch.amax(torch.log1p(self.relu(logits)), dim=1)}

    def _filter_activations(
        self, sparse_activations: torch.Tensor, k_tokens: int
    ) -> list[torch.Tensor]:
        """Among the set of activations, select the ones with a score > 0."""
        scores, activations = torch.topk(input=sparse_activations, k=k_tokens, dim=-1)
        return [
            torch.index_select(
                activation, dim=-1, index=torch.nonzero(score, as_tuple=True)[0]
            )
            for score, activation in zip(scores, activations)
        ]

    def _update_activations(
        self, sparse_activations: torch.Tensor, k_tokens: int
    ) -> dict[str, torch.Tensor]:
        """Keep the top-k activations and zero out the rest."""
        activations = torch.topk(input=sparse_activations, k=k_tokens, dim=1).indices
        zero_tensor = torch.zeros_like(sparse_activations, dtype=int)
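        # Out-of-place scatter on a zeros mask rather than an in-place scatter_()
        # on the activations themselves; the in-place variant proved finicky on
        # multi-GPU (see the PR description above).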
        updated_sparse_activations = sparse_activations * zero_tensor.scatter(
            dim=1, index=activations.long(), value=1
        )

        return {
            "activations": activations,
            "sparse_activations": updated_sparse_activations,
        }

```

@raphaelsty (Owner) commented:

Splade with the call to the parent class's save_tokenizer_accelerate method :)

@raphaelsty (Owner) commented:

> Hey, did you submit the comments? I can't see the suggested code anywhere, though it might be me being holiday-tired...

Ahah missed this, sorry.

> (as for neural-cherche itself, I really like the lightweight-ness of the lib, but currently I'm running into some issues where my models end up stuck in some kind of "compressed similarity" land and hard negatives are always extremely close to positives in similarity, which doesn't happen with the main ColBERT-codebase -- I'm training a ColBERT from scratch and will try to diagnose once I have more time!)

It could come from the loss function, which is quite simple? I'd love to get your feedback on this if you find anything.

Overall, I think it's fine to push your work to master if we use the self.accelerate flag. It will be a first step towards accelerating the lib over multiple GPUs! :)

@bclavie (Contributor, Author) commented Dec 23, 2023

> Ahah missed this, sorry.

No worries, I've applied the changes 1:1, except for the tutorial page (added that support is partial/in-progress, so people don't get the impression it's fully supported yet!)

> It could come from the loss function, which is quite simple? I'd love to get your feedback on this if you find anything.

I think that's probably it... I'll definitely try to figure out exactly which component has the biggest impact once I've got some more time.
